TidyModels is the successor to Max Kuhn's CARET and can be used for a wide range of machine learning tasks. The framework takes a different approach to modelling, allowing for a more structured workflow, and, like the tidyverse, provides a whole family of packages for making the machine learning process easier. I will touch on a number of these packages in the following subsections.
This framework supersedes the modelling approach in R for Data Science; Hadley Wickham admitted at the time that R needed a better modelling solution, and Max Kuhn and his team have delivered one.
The aim of this webinar is to:
The framework of a TidyModels approach flows as follows:
I will show you the steps in the following tutorials.
I will load in the stranded patient data. A stranded patient is a patient who has been in hospital for longer than 7 days; we also call these patients Long Waiters. The import steps below use the readr package to load the CSV:
# Read in the data
strand_pat <- read_csv("Data/Stranded_Data.csv") %>%
  setNames(c("stranded_class", "age", "care_home_ref_flag", "medically_safe_flag",
             "hcop_flag", "needs_mental_health_support_flag",
             "previous_care_in_last_12_month", "admit_date", "frail_descrip")) %>%
  mutate(stranded_class = factor(stranded_class)) %>%
  drop_na()
##
## -- Column specification --------------------------------------------------------
## cols(
## Stranded.label = col_character(),
## Age = col_double(),
## Care.home.referral = col_double(),
## MedicallySafe = col_double(),
## HCOP = col_double(),
## Mental_Health_Care = col_double(),
## Periods_of_previous_care = col_double(),
## admit_date = col_character(),
## frailty_index = col_character()
## )
## # A tibble: 6 x 9
## stranded_class age care_home_ref_flag medically_safe_flag hcop_flag
## <fct> <dbl> <dbl> <dbl> <dbl>
## 1 Not Stranded 50 0 0 0
## 2 Not Stranded 31 1 0 1
## 3 Not Stranded 32 0 1 0
## 4 Not Stranded 69 1 1 0
## 5 Not Stranded 33 0 0 1
## 6 Stranded 75 1 1 0
## # ... with 4 more variables: needs_mental_health_support_flag <dbl>,
## # previous_care_in_last_12_month <dbl>, admit_date <chr>, frail_descrip <chr>
As this is a classification problem we need to look at the class imbalance in the outcome variable, i.e. the thing we are trying to predict.
The following code looks at the class balance as a volume and as a proportion. I then take the second index from the class balance table: the number of long waiters should be lower than the number who are not, otherwise we are offering a very poor service to patients.
class_bal_table <- table(strand_pat$stranded_class)
prop_tab <- prop.table(class_bal_table)
upsample_ratio <- class_bal_table[2] / sum(class_bal_table)
print(prop_tab); print(class_bal_table); print(upsample_ratio)
##
## Not Stranded     Stranded
##    0.6552217    0.3447783
##
## Not Stranded     Stranded
##          458          241
##  Stranded
## 0.3447783
It is always a good idea to observe the data structures of the items we are working with. I generally separate the names of the variables into factor, integer/numeric and character vectors:
strand_pat$admit_date <- as.Date(strand_pat$admit_date, format="%d/%m/%Y") #Format date to be date to work with recipes steps
factors <- names(select_if(strand_pat, is.factor))
numbers <- names(select_if(strand_pat, is.numeric))
characters <- names(select_if(strand_pat, is.character))
print(factors); print(numbers); print(characters)
## [1] "stranded_class"
## [1] "age" "care_home_ref_flag"
## [3] "medically_safe_flag" "hcop_flag"
## [5] "needs_mental_health_support_flag" "previous_care_in_last_12_month"
## [1] "frail_descrip"
The rsample package makes it easy to divide your data up. To view all the functionality, navigate to the rsample vignette.
We will divide the data into a training and a test sample. This is the simplest method of testing your model's accuracy and future performance on unseen data. Here we treat the test data as the unseen data, allowing us to evaluate whether the model is fit for being released into the wild or not.
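The split itself is not shown in the original chunk; assuming the objects `train_data` and `test_data` used later in this tutorial, a minimal sketch with rsample would be (the proportion and seed here are my own illustrative choices):

```r
library(rsample)

set.seed(123) # For reproducibility of the split
# Hold out a portion of the data; 75% train / 25% test is an illustrative choice
strand_split <- initial_split(strand_pat, prop = 0.75, strata = stranded_class)
train_data <- training(strand_split) # Data the model learns from
test_data  <- testing(strand_split)  # Unseen data for evaluation
```

Stratifying on `stranded_class` keeps the class proportions similar in both samples, which matters given the imbalance we observed above.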
Recipes is an excellent package. For years I have done feature engineering, dummy coding and feature selection with CARET, also a great package, but recipes makes the process much simpler. The first part of a recipe is to specify the model formula and data; you then add recipe steps, which is meant to mimic baking by adding the specific ingredients. For all the particular steps that recipes contains, go directly to the recipes site.
stranded_rec <-
  recipe(stranded_class ~ ., data = train_data) %>%
  # The stranded class is what we are trying to predict, using the training data
  step_date(admit_date, features = c("dow", "month")) %>%
  # step_date creates additional features from the date
  step_rm(admit_date) %>%
  # Remove the date, as we have created features from it; if left in, the dreaded multicollinearity may be present
  themis::step_upsample(stranded_class, over_ratio = as.numeric(upsample_ratio)) %>%
  # Upsampling recipe step to boost the minority class, i.e. stranded patients
  step_dummy(all_nominal(), -all_outcomes()) %>%
  # Automatically create dummy variables for all categorical (nominal) variables
  step_zv(all_predictors()) %>%
  # Get rid of features that have zero variance
  step_normalize(all_predictors()) # ML models train better when the data is centred and scaled

print(stranded_rec) # Terminology is to use recipe
## Data Recipe
##
## Inputs:
##
## role #variables
## outcome 1
## predictor 8
##
## Operations:
##
## Date features from admit_date
## Delete terms admit_date
## Up-sampling based on stranded_class
## Dummy variables from all_nominal(), -all_outcomes()
## Zero variance filter on all_predictors()
## Centering and scaling for all_predictors()
I have previously covered some of these steps in a CARET tutorial. For the full list of recipes steps, refer to the link above the code chunk.
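As a quick sanity check, a recipe can be estimated on the training data and its output inspected before it goes into a workflow. This is a sketch, not part of the original pipeline; `prep()` estimates the steps and `bake()` applies them:

```r
# Estimate the recipe steps on the training data
prepped <- prep(stranded_rec, training = train_data)
# new_data = NULL returns the processed (upsampled) training set
baked_train <- bake(prepped, new_data = NULL)
glimpse(baked_train) # Inspect the dummy, date and normalised columns
```

This is useful for confirming that the dummy and date features look as expected before fitting anything.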
The Parsnip package is the modelling interface for TidyModels. Parsnip does not yet have all of the algorithms present in CARET, but it makes it much simpler to work in the tidy way.
Here we will create a basic logistic regression as our baseline model. If you would like a second tutorial on model ensembling in TidyModels with Baguette and Stacks, I would be happy to arrange this, but those are a session in themselves.
Logistic regression is the choice here because it is a generalised linear model that most people have encountered.
TidyModels has a workflow structure which we will build in the next few steps:
In TidyModels you have to create an instance of the model in memory before working with it:
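The instantiation code is not shown in the original chunk; the printed specification below suggests it was created along these lines (a sketch using parsnip's `logistic_reg()` with the name `lr_mod`, which the workflow step later assumes):

```r
library(parsnip)

# Instantiate a logistic regression model with the glm computational engine
lr_mod <- logistic_reg() %>%
  set_engine("glm")
print(lr_mod)
```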
## Logistic Regression Model Specification (classification)
##
## Computational engine: glm
The next step is to create the model workflow, connecting the newly instantiated model to the recipe:
# Create model workflow
strand_wf <-
workflow() %>%
add_model(lr_mod) %>%
add_recipe(stranded_rec)
print(strand_wf)
## == Workflow ====================================================================
## Preprocessor: Recipe
## Model: logistic_reg()
##
## -- Preprocessor ----------------------------------------------------------------
## 6 Recipe Steps
##
## * step_date()
## * step_rm()
## * step_upsample()
## * step_dummy()
## * step_zv()
## * step_normalize()
##
## -- Model -----------------------------------------------------------------------
## Logistic Regression Model Specification (classification)
##
## Computational engine: glm
The next step is fitting the model to our data:
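The fit call itself is not shown in the original chunk; assuming the workflow object `strand_wf` from above and a fitted object named `strand_fit` (which the prediction steps later assume), it would look something like:

```r
# Fit the workflow (recipe preprocessing + model) to the training data
strand_fit <- strand_wf %>%
  fit(data = train_data)
```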
The final step is to use the pull_workflow_fit() function to retrieve the fit from the workflow:
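Assuming the fitted workflow is called `strand_fit`, the extraction might be sketched as follows, with broom's `tidy()` producing the tibble of coefficients shown below:

```r
# Retrieve the underlying parsnip fit from the workflow and
# tidy the glm coefficients into a tibble
strand_fitted <- strand_fit %>%
  pull_workflow_fit() %>%
  tidy()
print(strand_fitted)
```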
## # A tibble: 18 x 5
## term estimate std.error statistic p.value
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 (Intercept) -0.242 0.172 -1.41 1.60e- 1
## 2 age 0.296 0.261 1.13 2.58e- 1
## 3 care_home_ref_flag 0.204 0.115 1.78 7.57e- 2
## 4 medically_safe_flag -0.173 0.120 -1.45 1.48e- 1
## 5 hcop_flag -0.0443 0.114 -0.390 6.97e- 1
## 6 needs_mental_health_support_flag 0.0646 0.116 0.558 5.77e- 1
## 7 previous_care_in_last_12_month 2.98 0.477 6.24 4.48e-10
## 8 frail_descrip_Fall.patient.history -0.185 0.150 -1.23 2.18e- 1
## 9 frail_descrip_Mobility.problems 0.0864 0.138 0.625 5.32e- 1
## 10 frail_descrip_No.index.item 0.144 0.279 0.517 6.05e- 1
## 11 admit_date_dow_Mon -0.149 0.169 -0.884 3.77e- 1
## 12 admit_date_dow_Tue 0.0955 0.154 0.620 5.36e- 1
## 13 admit_date_dow_Wed 0.232 0.162 1.44 1.51e- 1
## 14 admit_date_dow_Thu 0.175 0.147 1.19 2.34e- 1
## 15 admit_date_dow_Fri 0.0203 0.165 0.123 9.02e- 1
## 16 admit_date_dow_Sat 0.181 0.150 1.20 2.29e- 1
## 17 admit_date_month_Feb 0.0144 0.133 0.108 9.14e- 1
## 18 admit_date_month_Dec 0.00950 0.122 0.0778 9.38e- 1
As an optional step I have created a plot to visualise significance. This will only work with linear and generalised linear models, which report a p-value for each coefficient estimate. The visualisation code is below:
# Add significance column to tibble using mutate
strand_fitted <- strand_fitted %>%
mutate(Significance = ifelse(p.value < 0.05, "Significant", "Insignificant")) %>%
arrange(desc(p.value))
#Create a ggplot object to visualise significance
plot <- strand_fitted %>%
  ggplot(mapping = aes(x = term, y = p.value, fill = Significance)) +
  geom_col() +
  theme(axis.text.x = element_text(
    face = "bold", color = "#0070BA",
    size = 8, angle = 90)
  ) +
  labs(y = "P value", x = "Terms",
       title = "P value significance chart",
       subtitle = "A chart to represent the significant variables in the model",
       caption = "Produced by Gary Hutson")

#print("Creating plot of P values")
#print(plot)
plotly::ggplotly(plot)
Now we will assess how well the model predicts on the test (holdout) data, to evaluate whether we want to productionise the model or abandon it at this stage. This is implemented below:
class_pred <- predict(strand_fit, test_data) #Get the class label predictions
prob_pred <- predict(strand_fit, test_data, type="prob") #Get the probability predictions
lr_predictions <- data.frame(class_pred, prob_pred) %>%
  setNames(c("LR_Class", "LR_NotStrandedProb", "LR_StrandedProb")) # Combine and rename

stranded_preds <- test_data %>%
  bind_cols(lr_predictions)

print(tail(lr_predictions))
##        LR_Class LR_NotStrandedProb LR_StrandedProb
## 169 Not Stranded 0.8280968 0.1719032
## 170 Not Stranded 0.7650720 0.2349280
## 171 Not Stranded 0.8256528 0.1743472
## 172 Not Stranded 0.8405598 0.1594402
## 173 Not Stranded 0.8371738 0.1628262
## 174 Not Stranded 0.8501583 0.1498417
Yardstick is another tool in the TidyModels arsenal. It is useful for generating quick summary statistics and evaluation metrics. I will grab the area under the curve estimates to show how well the model fits:
roc_plot <-
  stranded_preds %>%
  roc_curve(truth = stranded_class, LR_NotStrandedProb) %>%
  autoplot()

print(roc_plot)
I like ROC plots, but they only show the trade-off between sensitivity (how well the model predicts stranded) and specificity (how well it predicts not stranded). For binomial classification problems I also like to look at the overall accuracy and balanced accuracy on a confusion matrix.
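For completeness, the area under the ROC curve can also be reduced to a single number with yardstick's `roc_auc()`; this is a sketch using the same truth and probability columns as the curve above:

```r
# Area under the ROC curve; the first factor level ("Not Stranded") is the
# default event, matching the probability column passed in
stranded_preds %>%
  roc_auc(truth = stranded_class, LR_NotStrandedProb)
```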
I use the CARET package and its confusion matrix functions to perform this:
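The attach call is not in the original chunk, but the masking messages below come from loading the package, presumably like so:

```r
library(caret) # Masks some yardstick and purrr functions, as the messages show
```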
## Loading required package: lattice
##
## Attaching package: 'caret'
## The following objects are masked from 'package:yardstick':
##
## precision, recall, sensitivity, specificity
## The following object is masked from 'package:purrr':
##
## lift
cm <- caret::confusionMatrix(stranded_preds$stranded_class,
                             stranded_preds$LR_Class,
                             positive = "Stranded")
print(cm)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Not Stranded Stranded
## Not Stranded 106 2
## Stranded 35 31
##
## Accuracy : 0.7874
## 95% CI : (0.719, 0.8456)
## No Information Rate : 0.8103
## P-Value [Acc > NIR] : 0.8093
##
## Kappa : 0.4998
##
## Mcnemar's Test P-Value : 1.435e-07
##
## Sensitivity : 0.9394
## Specificity : 0.7518
## Pos Pred Value : 0.4697
## Neg Pred Value : 0.9815
## Prevalence : 0.1897
## Detection Rate : 0.1782
## Detection Prevalence : 0.3793
## Balanced Accuracy : 0.8456
##
## 'Positive' Class : Stranded
##
On the back of the Advanced Modelling course I delivered for the NHS-R Community, I have created a package for working with the outputs of a confusion matrix. The package flattens binary and multi-class confusion matrix results into a single row.
To install the package, use the remotes package to bring in ConfusionTableR, which is available from my GitHub site.
# Load in my ConfusionTableR package to visualise this
# remotes::install_github("https://github.com/StatsGary/ConfusionTableR") # Use remotes to install from GitHub
library(ConfusionTableR)

cm_plot <- ConfusionTableR::binary_visualiseR(cm, class_label1 = "Not Stranded",
                                              class_label2 = "Stranded",
                                              quadrant_col1 = "#53BFD3", quadrant_col2 = "#006838",
                                              text_col = "white",
                                              custom_title = "Stranded patient Confusion Matrix")

# Flatten to store in a database
cm_results <- ConfusionTableR::binary_class_cm(cm)
print(cm_results)
##   Pred_Not.Stranded_Ref_Not.Stranded Pred_Stranded_Ref_Not.Stranded
## 1 106 35
## Pred_Not.Stranded_Ref_Stranded Pred_Stranded_Ref_Stranded Accuracy Kappa
## 1 2 31 0.7873563 0.4997669
## AccuracyLower AccuracyUpper AccuracyNull AccuracyPValue McnemarPValue
## 1 0.7190153 0.8456409 0.8103448 0.8092972 1.434553e-07
## Sensitivity Specificity Pos.Pred.Value Neg.Pred.Value Precision Recall
## 1 0.9393939 0.751773 0.469697 0.9814815 0.469697 0.9393939
## F1 Prevalence Detection.Rate Detection.Prevalence Balanced.Accuracy
## 1 0.6262626 0.1896552 0.1781609 0.3793103 0.8455835
## cm_ts
## 1 2021-02-23 15:28:08
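Since the flattened single-row output is designed for storage, one way to persist it would be via DBI. This is purely an illustrative sketch; the SQLite file and table name are my own assumptions:

```r
library(DBI)
library(RSQLite)

# Illustrative: append the flattened confusion matrix to a monitoring table
con <- dbConnect(RSQLite::SQLite(), "model_monitoring.sqlite")
dbWriteTable(con, "confusion_matrix_log", cm_results, append = TRUE)
dbDisconnect(con)
```

Logging a row per model run like this makes it easy to track accuracy and balanced accuracy over time.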
The next markdown document will look at how to improve your models with model selection, k-fold cross validation and hyperparameter tuning. I am also thinking of running an ensembling course off the back of this, so please contact me if that would interest you.